{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Iris classification with scikit-learn\n", "\n", "Here we use the well-known Iris species dataset to illustrate how SHAP can explain the output of many different model types, from k-nearest neighbors to neural networks. This dataset is very small, with only 150 samples. We use a random 80/20 split, which gives 120 samples for training and 30 for testing the models. Because this is a small dataset with only a few features, we use the entire training set as the background dataset. In problems with more features we would want to pass a compact summary of the training data instead, such as its median or a set of weighted k-means centroids (see `shap.kmeans`). Although we only have a few samples, the prediction problem is fairly easy and all methods achieve near-perfect accuracy on the test set. What's interesting is how different methods sometimes rely on different sets of features for their predictions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import time\n", "\n", "import numpy as np\n", "import sklearn.linear_model\n", "import sklearn.neighbors\n", "import sklearn.svm\n", "from sklearn.model_selection import train_test_split\n", "\n", "import shap\n", "\n", "X, y = shap.datasets.iris()\n", "X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n", "\n", "# rather than use the whole training set to estimate expected values, we could summarize it with\n", "# a set of weighted k-means centroids, each weighted by the number of points it represents. But this\n", "# dataset is so small that we don't worry about it\n", "# X_train_summary = shap.kmeans(X_train, 50)\n", "\n", "\n", "def print_accuracy(f):\n", "    print(f\"Accuracy = {100 * np.sum(f(X_test) == Y_test) / len(Y_test)}%\")\n", "    time.sleep(0.5)  # to let the print get out before any progress bars\n", "\n", "\n", "shap.initjs()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## K-nearest neighbors" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 96.66666666666667%\n" ] } ], "source": [ "knn = sklearn.neighbors.KNeighborsClassifier()\n", "knn.fit(X_train, Y_train)\n", "\n", "print_accuracy(knn.predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explain a single prediction from the test set" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n" ] }, { "data": { "text/html": [ "\n", "
\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "explainer = shap.Explainer(knn.predict_proba, X_train)\n", "shap_values = explainer(X_test)\n", "\n", "# visualize the first prediction's explanation for the 'setosa' class\n", "shap.plots.force(shap_values[0, ..., 0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explain all the predictions in the test set" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2c0ec8442f734764bab3e6c5a97f98d8", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# visualize all the predictions' explanations for the 'setosa' class\n", "shap.plots.force(shap_values[..., 0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Support vector machine with a linear kernel" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 100.0%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9ea8f8a814e44baba56bd6649eeea0f6", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "svc_linear = sklearn.svm.SVC(kernel=\"linear\", probability=True)\n", "svc_linear.fit(X_train, Y_train)\n", "print_accuracy(svc_linear.predict)\n", "\n", "# explain all the predictions in the test set\n", "explainer = shap.Explainer(svc_linear.predict_proba, X_train)\n", "shap_values = explainer(X_test)\n", "\n", "# visualize the test set explanations for the 'setosa' class\n", "shap.plots.force(shap_values[..., 0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Support vector machine with a radial basis function kernel" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 100.0%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d77859c2b23f410c9d0aea1b5f474516", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "svc_rbf = sklearn.svm.SVC(kernel=\"rbf\", probability=True)\n", "svc_rbf.fit(X_train, Y_train)\n", "print_accuracy(svc_rbf.predict)\n", "\n", "# explain all the predictions in the test set\n", "explainer = shap.Explainer(svc_rbf.predict_proba, X_train)\n", "shap_values = explainer(X_test)\n", "\n", "# visualize the test set explanations for the 'setosa' class\n", "shap.plots.force(shap_values[..., 0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Logistic regression" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 100.0%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3a63d0290d4744c6b57067acba3f9d77", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "linear_lr = sklearn.linear_model.LogisticRegression(solver=\"newton-cg\")\n", "linear_lr.fit(X_train, Y_train)\n", "print_accuracy(linear_lr.predict)\n", "\n", "# explain all the predictions in the test set\n", "explainer = shap.Explainer(linear_lr.predict_proba, X_train)\n", "shap_values = explainer(X_test)\n", "\n", "# visualize the test set explanations for the 'setosa' class\n", "shap.plots.force(shap_values[..., 0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decision tree" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 100.0%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "70699397a390482f8a42a149bdb60e1d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sklearn.tree\n", "\n", "dtree = sklearn.tree.DecisionTreeClassifier(min_samples_split=2)\n", "dtree.fit(X_train, Y_train)\n", "print_accuracy(dtree.predict)\n", "\n", "# explain all the predictions in the test set\n", "explainer = shap.Explainer(dtree.predict_proba, X_train)\n", "shap_values = explainer(X_test)\n", "\n", "# visualize the test set explanations for the 'setosa' class\n", "shap.plots.force(shap_values[..., 0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random forest" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 100.0%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "21620db188a64c9185448d836fa48e45", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "rforest = RandomForestClassifier(n_estimators=100, max_depth=None, min_samples_split=2, random_state=0)\n", "rforest.fit(X_train, Y_train)\n", "print_accuracy(rforest.predict)\n", "\n", "# explain all the predictions in the test set\n", "explainer = shap.Explainer(rforest.predict_proba, X_train)\n", "shap_values = explainer(X_test)\n", "\n", "# visualize the test set explanations for the 'setosa' class\n", "shap.plots.force(shap_values[..., 0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Neural network" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 96.66666666666667%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Using 120 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "893fe22dc360430980bab9d2e5abac4f", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.neural_network import MLPClassifier\n", "\n", "nn = MLPClassifier(solver=\"lbfgs\", alpha=1e-1, hidden_layer_sizes=(5, 2), random_state=0)\n", "nn.fit(X_train, Y_train)\n", "print_accuracy(nn.predict)\n", "\n", "# explain all the predictions in the test set\n", "explainer = shap.Explainer(nn.predict_proba, X_train)\n", "shap_values = explainer(X_test)\n", "\n", "# visualize the test set explanations for the 'setosa' class\n", "shap.plots.force(shap_values[..., 0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary plot\n", "\n", "Force plots are good for seeing how each feature pushes an individual prediction away from the base value. A beeswarm plot instead summarizes the distribution of SHAP values for every feature across the entire test set, so it shows both which features matter most and in which direction they push the prediction. Here we show the beeswarm plot for the 'setosa' class using the random forest model. Because `shap_values` currently holds the neural network's explanations, we recompute them for the random forest first." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# recompute the explanations for the random forest model\n", "explainer = shap.Explainer(rforest.predict_proba, X_train)\n", "shap_values = explainer(X_test)\n", "\n", "# summarize the effects of all the features on the 'setosa' class\n", "shap.plots.beeswarm(shap_values[..., 0])" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 1 }